# Copyright (c) 2022 Google LLC
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of
# this software and associated documentation files (the "Software"), to deal in
# the Software without restriction, including without limitation the rights to
# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
# the Software, and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

import dataclasses
import functools as ft
import inspect
import itertools as it
import sys
import warnings
import weakref
from collections.abc import Callable
from contextlib import AbstractContextManager
from typing import (
    Any,
    get_args,
    get_type_hints,
    Literal,
    NoReturn,
    overload,
    ParamSpec,
    TypeVar,
    Union,
)

from jaxtyping import AbstractArray

from ._config import config
from ._errors import AnnotationError, TypeCheckError
from ._storage import pop_shape_memo, push_shape_memo, shape_str


# Type variables used by the `jaxtyped` overload signatures below, so that static type
# checkers see the decorated object's own type come back out of the decorator.
_Params = ParamSpec("_Params")
_Return = TypeVar("_Return")
_T = TypeVar("_T")
# Not `TypeVar(..., type, Callable)` as else the output type of our first overload is
# just `type`, and not the particular class that is decorated.
_TypeOrCallable = TypeVar("_TypeOrCallable", bound=Union[type, Callable])


class _Sentinel:
    def __repr__(self):
        return "sentinel"


# Unique marker used as a default argument value, so that "not passed" can be told
# apart from an explicit `None` (`typechecker=None` is meaningful to `jaxtyped`).
_sentinel = _Sentinel()
# One-shot flag: `jaxtyped` sets this to `False` on its first call, after attempting
# to register this file with JAX's traceback utilities.
_tb_flag = True


def _apply_typechecker(typechecker, fn):
    """Calls `typechecker(fn)` in an isolated frame, returning the result.

    This avoids reference cycles that can otherwise occur if `typechecker` grabs
    the calling frame's locals.

    **Arguments:**

    - `typechecker`: a decorator-style callable; it is invoked with `fn` as its single
        positional argument.
    - `fn`: the function to hand to `typechecker`.

    **Returns:**

    Whatever `typechecker(fn)` returns.
    """
    # Deliberately a bare `return` with no intermediate local: the only names in this
    # frame are `typechecker` and `fn`.
    return typechecker(fn)


# Overload 1: `jaxtyped(typechecker=...)` with no positional argument. Returns a
# decorator that hands back the same type of object it is applied to.
@overload
def jaxtyped(
    *,
    typechecker=_sentinel,
) -> Callable[[_TypeOrCallable], _TypeOrCallable]: ...


# Overload 2: applied directly to a class; the same class type comes back.
@overload
def jaxtyped(fn: type[_T], *, typechecker=_sentinel) -> type[_T]: ...


# Overload 3: applied directly to a callable; its parameters and return type are
# preserved.
@overload
def jaxtyped(
    fn: Callable[_Params, _Return], *, typechecker=_sentinel
) -> Callable[_Params, _Return]: ...


# Overload 4: `jaxtyped("context")`, which returns a context manager.
@overload
def jaxtyped(fn: Literal["context"]) -> AbstractContextManager[None]: ...


def jaxtyped(fn=_sentinel, *, typechecker=_sentinel):
    """Decorate a function with this to perform runtime type-checking of its arguments
    and return value. Decorate a dataclass to perform type-checking of its `__init__`
    method.

    !!! Example

        ```python
        # Import both the annotation and the `jaxtyped` decorator from `jaxtyping`
        from jaxtyping import Array, Float, jaxtyped

        # Use your favourite typechecker: usually one of the two lines below.
        from typeguard import typechecked as typechecker
        from beartype import beartype as typechecker

        # Type-check a function
        @jaxtyped(typechecker=typechecker)
        def batch_outer_product(x: Float[Array, "b c1"],
                                y: Float[Array, "b c2"]
                              ) -> Float[Array, "b c1 c2"]:
            return x[:, :, None] * y[:, None, :]

        # Type-check a dataclass
        from dataclasses import dataclass

        @jaxtyped(typechecker=typechecker)
        @dataclass
        class MyDataclass:
            x: int
            y: Float[Array, "b c"]
        ```

    **Arguments:**

    - `fn`: The function or dataclass to decorate. In practice if you want to use
        dataclasses with JAX, then
        [`equinox.Module`](https://docs.kidger.site/equinox/api/module/module/) is our
        recommended approach:
        ```python
        import equinox as eqx

        @jaxtyped(typechecker=typechecker)
        class MyModule(eqx.Module):
            ...
        ```

    - `typechecker`: Keyword-only argument: the runtime type-checker to use. This should
        be a function decorator that will raise an exception if there is a type error,
        e.g.
        ```python
        @typechecker
        def f(x: int):
            pass

        f("a string is not an integer")  # this line should raise an exception
        ```
        Common choices are `typechecker=beartype.beartype` or
        `typechecker=typeguard.typechecked`. Can also be set as `typechecker=None` to
        skip automatic runtime type-checking, but still support manual `isinstance`
        checks inside the function body:
        ```python
        @jaxtyped(typechecker=None)
        def f(x):
            assert isinstance(x, Float[Array, "batch channel"])
        ```

    **Returns:**

    If `fn` is a function (including a `staticmethod`, `classmethod`, or `property`),
    then a wrapped function is returned.

    If `fn` is a dataclass, then `fn` is returned directly, and additionally its
    `__init__` method is wrapped and modified in-place.

    !!! Info "Old syntax"

        jaxtyping previously (before v0.2.24) recommended using this double-decorator
        syntax:
        ```python
        @jaxtyped
        @typechecker
        def f(...): ...
        ```
        This is still supported, but will now raise a warning recommending the
        `jaxtyped(typechecker=typechecker)` syntax discussed above. (Which will produce
        easier-to-debug error messages: under the hood, the new syntax more carefully
        manipulates the typechecker so as to determine where a type-check error arises.)

    ??? Info "Notes for advanced users"

        **Dynamic contexts:**

        Put precisely, the axis names in e.g. `Float[Array, "batch channels"]` and the
        structure names in e.g. `PyTree[int, "T"]` are all scoped to the thread-local
        dynamic context of a `jaxtyped`-wrapped function. If from within that function
        we then call another `jaxtyped`-wrapped function, then a new context is pushed
        to the stack. The axis sizes and PyTree structures of this inner function will
        then not be compared against the axis sizes and PyTree structures of the outer
        function. After the inner function returns then this inner context is popped
        from the stack, and the previous context is returned to.

        **isinstance:**

        Binding of a value against a name is done with an `isinstance` check, for
        example `isinstance(jnp.zeros((3, 4)), Float[Array, "dim1 dim2"])` will bind
        `dim1=3` and `dim2=4`. In practice these `isinstance` checks are usually done by
        the run-time typechecker `typechecker` that is supplied as an argument.

        This can also be done manually: add `isinstance` checks inside a function body
        and they will contribute to the same collection of consistency checks as are
        performed by the typechecker on the arguments and return values. (Or you can
        forgo such a typechecker altogether -- i.e. `typechecker=None` -- and only do
        your own manual `isinstance` checks.)

        Only `isinstance` checks that pass will contribute to the store of values; those
        that fail will not. As such it is safe to write e.g.
        `assert not isinstance(x, Float32[Array, "foo"])`.

        **Decoupling contexts from function calls:**

        If you would like to call a new function *without* creating a new
        dynamic context (and using the same set of axis and structure values), then
        simply do not add a `jaxtyped` decorator to your inner function, whilst
        continuing to perform type-checking in whatever way you prefer.

        Conversely, if you would like a new dynamic context *without* calling a new
        function, then in addition to the usage discussed above, `jaxtyped` also
        supports being used as a context manager, by passing it the string `"context"`:
        ```python
        with jaxtyped("context"):
            assert isinstance(x, Float[Array, "batch channel"])
        ```
        This is equivalent to placing this code inside a new function wrapped in
        `jaxtyped(typechecker=None)`. Usage like this is very rare; it's mostly only
        useful when working at the global scope.
    """

    # On the first call only: if JAX's traceback utilities can be imported, register
    # this file with them. Any import failure is ignored, and the flag is cleared
    # either way so that this is attempted just once.
    global _tb_flag
    if _tb_flag:
        try:
            import jax._src.traceback_util as traceback_util
        except Exception:
            pass
        else:
            traceback_util.register_exclusion(__file__)
        _tb_flag = False

    # First handle the `jaxtyped("context")` usage, which is a special case.
    if fn == "context":
        if typechecker is not _sentinel:
            raise ValueError(
                "Cannot use `jaxtyped` as a context with a typechecker. That is, "
                "`with jaxtyped('context', typechecker=...):`. is not allowed. In this "
                "case the type checker does not actually do anything, as there is no "
                "function to type-check."
            )
        return _JaxtypingContext()

    # Now check that a typechecker has been explicitly declared. (Or explicitly declared
    # as not being used, via `typechecker=None`.)
    # This is needed just for backward compatibility: an undeclared typechecker
    # corresponds to the old double-decorator syntax.
    if typechecker is _sentinel:
        # This branch will also catch the easy-to-make mistake of
        # ```python
        # @jaxtyped(typechecker)
        # def foo(...):
        # ```
        # which is a bug as `typechecker` is interpreted as the function to decorate!
        warnings.warn(
            "As of jaxtyping version 0.2.24, jaxtyping now prefers the syntax\n"
            "```\n"
            "from jaxtyping import jaxtyped\n"
            "# Use your favourite typechecker: usually one of the two lines below.\n"
            "from typeguard import typechecked as typechecker\n"
            "from beartype import beartype as typechecker\n"
            "\n"
            "@jaxtyped(typechecker=typechecker)\n"
            "def foo(...):\n"
            "```\n"
            "and the old double-decorator syntax\n"
            "```\n"
            "@jaxtyped\n"
            "@typechecker\n"
            "def foo(...):\n"
            "```\n"
            "should no longer be used. (It will continue to work as it did before, but "
            "the new approach will produce more readable error messages.)\n"
            "In particular note that `typechecker` must be passed via keyword "
            "argument; the following is not valid:\n"
            "```\n"
            "@jaxtyped(typechecker)\n"
            "def foo(...):\n"
            "```\n",
            stacklevel=2,
        )
        typechecker = None

    if fn is _sentinel:
        # Called as `jaxtyped(typechecker=...)`: return a decorator awaiting `fn`.
        return ft.partial(jaxtyped, typechecker=typechecker)
    elif inspect.isclass(fn):
        if dataclasses.is_dataclass(fn) and typechecker is not None:
            # An `__init__` whose globals belong to this module is one of the wrappers
            # created below, i.e. this class has already been processed.
            try:
                already_wrapped = fn.__init__.__globals__["__name__"] == __name__
            except (AttributeError, KeyError):
                pass
            else:
                if already_wrapped:
                    return fn
            fn.__init__ = jaxtyped(fn.__init__, typechecker=typechecker)
        return fn
    # It'd be lovely if we could handle arbitrary descriptors, and not just the builtin
    # ones. Unfortunately that means returning a class instance with a __get__ method,
    # and that turns out to break loads of other things. See beartype issue #211 and
    # jaxtyping issue #71.
    elif isinstance(fn, classmethod):
        return classmethod(jaxtyped(fn.__func__, typechecker=typechecker))
    elif isinstance(fn, staticmethod):
        return staticmethod(jaxtyped(fn.__func__, typechecker=typechecker))
    elif isinstance(fn, property):
        if fn.fget is None:
            fget = None
        else:
            fget = jaxtyped(fn.fget, typechecker=typechecker)
        if fn.fset is None:
            fset = None
        else:
            fset = jaxtyped(fn.fset, typechecker=typechecker)
        if fn.fdel is None:
            fdel = None
        else:
            fdel = jaxtyped(fn.fdel, typechecker=typechecker)
        # Pass the docstring through explicitly: a property built with an explicit
        # `doc=` (or without a getter) would otherwise lose its documentation.
        return property(fget=fget, fset=fset, fdel=fdel, doc=fn.__doc__)
    else:
        if typechecker is None:
            # Probably being used in the old style as
            # ```
            # @jaxtyped
            # @typechecker
            # def foo(x: int): ...
            # ```
            # in which case make a best-effort attempt to add shape information for any
            # type errors.

            # we want to detect generators, and ignore return annotations on them,
            # to avoid issues with O(n) typechecking trying to typecheck yielded values
            wrp = fn
            while hasattr(wrp, "__wrapped__"):
                wrp = wrp.__wrapped__

            if inspect.isgeneratorfunction(wrp) or inspect.isasyncgenfunction(wrp):
                # recursively parse all the annotations, and mark all the jaxtyping
                # annotations as not needing instance checks, while still being
                # visible as original ones for the typechecker
                def modify_annotation(ann):
                    if inspect.isclass(ann) and issubclass(ann, AbstractArray):
                        ann.make_transparent()

                    for sub_ann in get_args(ann):
                        modify_annotation(sub_ann)

                # just to make sure: check that fn has valid return annotations
                if hasattr(fn, "__annotations__") and "return" in fn.__annotations__:
                    modify_annotation(fn.__annotations__["return"])

            signature = inspect.signature(fn)

            @ft.wraps(fn)
            def wrapped_fn(*args, **kwargs):  # pyright: ignore
                __tracebackhide__ = True
                bound = signature.bind(*args, **kwargs)
                bound.apply_defaults()
                memos = push_shape_memo(bound.arguments)
                try:
                    return fn(*args, **kwargs)
                except Exception as e:
                    # add_note api is support from python 3.11+
                    if sys.version_info >= (3, 11) and _no_jaxtyping_note(e):
                        shape_info = shape_str(memos)
                        if shape_info != "":
                            msg = (
                                "The preceding error occurred within the scope of a "
                                "`jaxtyping.jaxtyped` function, and may be due to a "
                                "typecheck error. "
                            )
                            e.add_note(_jaxtyping_note_str(_spacer + msg + shape_info))
                    raise
                finally:
                    pop_shape_memo()

        else:
            # New-style
            # ```
            # @jaxtyped(typechecker=typechecker)
            # def foo(x: int): ...
            # ```
            # in which case we can do a better job reporting errors.

            full_signature = inspect.signature(fn)
            # Namespace in which stringified annotations are resolved. Not every
            # callable is a plain function (the `getattr` fallbacks below already allow
            # for that), so `__globals__` may be missing: then try the innermost
            # `__wrapped__` object, and failing that use an empty namespace, for which
            # `_destring_annotation` falls back to `Any` on unresolvable strings.
            fn_globals = getattr(fn, "__globals__", None)
            if fn_globals is None:
                fn_globals = getattr(inspect.unwrap(fn), "__globals__", {})
            new_params = []
            for p_value in full_signature.parameters.values():
                p_annotation = _destring_annotation(p_value.annotation, fn_globals)
                p_value = p_value.replace(annotation=p_annotation)
                new_params.append(p_value)
            return_annotation = _destring_annotation(
                full_signature.return_annotation, fn_globals
            )
            full_signature = full_signature.replace(
                parameters=new_params, return_annotation=return_annotation
            )
            param_signature = full_signature.replace(return_annotation=Any)
            name = getattr(fn, "__name__", "<no name found>")
            qualname = getattr(fn, "__qualname__", "<no qualname found>")
            module = getattr(fn, "__module__", "<generated_by_jaxtyping>")

            # Use the same name so that typeguard warnings look correct.
            full_fn, output_name = _make_fn_with_signature(
                name, qualname, module, full_signature, output=True
            )
            param_fn = _make_fn_with_signature(
                name, qualname, module, param_signature, output=False
            )
            full_fn = _apply_typechecker(typechecker, full_fn)
            param_fn = _apply_typechecker(typechecker, param_fn)

            def wrapped_fn_impl(args, kwargs, bound, memos):
                __tracebackhide__ = True
                # First type-check just the parameters before the function is
                # called.
                try:
                    param_fn(*args, **kwargs)
                except AnnotationError:
                    raise
                except Exception:
                    # `_get_problem_arg` never returns: it reports its finding by
                    # raising `TypeCheckError`, which is caught immediately below.
                    try:
                        argmsg = _get_problem_arg(
                            param_signature,
                            args,
                            kwargs,
                            bound.arguments,
                            module,
                            typechecker,
                        )
                    except TypeCheckError as e:
                        argmsg = str(e)
                        try:
                            module_name = fn.__module__
                            qualname = fn.__qualname__
                        except AttributeError:
                            module_name = fn.__class__.__module__
                            qualname = fn.__class__.__qualname__
                        param_values = _pformat(bound.arguments, short_self=True)
                        param_hints = _remove_typing(param_signature)
                        msg = (
                            "Type-check error whilst checking the parameters of "
                            f"{module_name}.{qualname}.{argmsg}\n"
                            "----------------------\n"
                            f"Called with parameters: {param_values}\n"
                            f"Parameter annotations: {param_hints}.\n"
                            + shape_str(memos)
                        )
                        if config.jaxtyping_remove_typechecker_stack:
                            raise TypeCheckError(msg) from None
                        else:
                            raise TypeCheckError(msg) from e

                # Actually call the function.
                out = fn(*args, **kwargs)

                if full_signature.return_annotation is not inspect.Signature.empty:
                    # Now type-check the return value. We need to include the
                    # parameters in the type-checking here in case there are any
                    # type variables shared across the parameters and return.
                    #
                    # Incidentally this does mean that if `fn` mutates its arguments
                    # so that they no longer satisfy their type annotations, this
                    # will throw an error here. But that's like, super weird, so
                    # don't do that. An error in that scenario is probably still
                    # desirable.
                    #
                    # There is a small performance concern here when used in
                    # non-jit'd contexts, like PyTorch, due to the duplicate
                    # checking of the parameters. Unfortunately there doesn't seem
                    # to be a way around that, so c'est la vie.
                    kwargs[output_name] = out
                    try:
                        full_fn(*args, **kwargs)
                    except AnnotationError:
                        raise
                    except Exception as e:
                        try:
                            module_name = fn.__module__
                            qualname = fn.__qualname__
                        except AttributeError:
                            module_name = fn.__class__.__module__
                            qualname = fn.__class__.__qualname__
                        param_values = _pformat(bound.arguments, short_self=True)
                        return_value = _pformat(out, short_self=False)
                        param_hints = _remove_typing(param_signature)
                        return_hint = _remove_typing(full_signature.return_annotation)
                        # Show plain classes as `int` rather than `<class 'int'>`.
                        if return_hint.startswith("<class '") and return_hint.endswith(
                            "'>"
                        ):
                            return_hint = return_hint[8:-2]
                        msg = (
                            "Type-check error whilst checking the return value "
                            f"of {module_name}.{qualname}.\n"
                            f"Actual value: {return_value}\n"
                            f"Expected type: {return_hint}.\n"
                            "----------------------\n"
                            f"Called with parameters: {param_values}\n"
                            f"Parameter annotations: {param_hints}.\n"
                            + shape_str(memos)
                        )
                        if config.jaxtyping_remove_typechecker_stack:
                            raise TypeCheckError(msg) from None
                        else:
                            raise TypeCheckError(msg) from e

                return out

            wrapped_fn_holder = []  # Avoids introducing a reference cycle.

            @ft.wraps(fn)
            def wrapped_fn(*args, **kwargs):
                __tracebackhide__ = True

                # Skip all checking if it is disabled globally, or if either the
                # original function or this wrapper carries `__no_type_check__`.
                if (
                    config.jaxtyping_disable
                    or getattr(fn, "__no_type_check__", False)
                    or getattr(wrapped_fn_holder[0](), "__no_type_check__", False)
                ):
                    return fn(*args, **kwargs)

                # Raise bind-time errors before we do any shape analysis. (I.e. skip
                # the pointless jaxtyping information for a non-typechecking failure.)
                bound = param_signature.bind(*args, **kwargs)
                bound.apply_defaults()

                memos = push_shape_memo(bound.arguments)
                try:
                    # Put this in a separate frame to make debugging easier, without
                    # just always ending up on the `pop_shape_memo` line below.
                    return wrapped_fn_impl(args, kwargs, bound, memos)
                finally:
                    pop_shape_memo()

            wrapped_fn_holder.append(weakref.ref(wrapped_fn))

        return wrapped_fn


class _JaxtypingContext:
    """Context manager returned by `jaxtyped("context")`.

    Entering pushes a new, empty shape memo; leaving pops it again.
    """

    def __enter__(self) -> None:
        push_shape_memo({})
        return None

    def __exit__(self, exc_type, exc_value, exc_tb) -> None:
        pop_shape_memo()
        # A falsy result means an exception raised in the `with` body propagates.
        return None


def _destring_annotation(ann, ns):
    # Best-effort attempt to destringify annotations.
    def fn(value: ann):
        del value

    try:
        hints = get_type_hints(fn, ns, include_extras=True)
    except Exception:
        # Not just `NameError` but also e.g. `ValueError` in case we have e.g.
        # 'Float[Foo, "*foo *bar"]' and raise  from having multiple variadic
        # arguments. Sometimes this can still be useful to use for human
        # documentation purposes.
        return Any
    else:
        return hints["value"]


def _make_fn_with_signature(
    name: str, qualname: str, module: str, signature: inspect.Signature, output: bool
):
    """Dynamically creates a function `fn` with name `name` and signature `signature`.

    If `output=True` then `fn` will consume an additional keyword-only argument (in
    addition to the provided signature), and will directly return this argument. In this
    case the returned value from `_make_fn_with_signature` is a 2-tuple of `(fn, name)`,
    where `fn` is the generated function, and `name` is the name of this extra argument.

    If `output=False` then `fn` will just have a single `pass` statement, and the
    returned value from `_make_fn_with_signature` will just be `fn`.

    ---

    Note that this function operates by dynamically creating and eval'ing a string, not
    simply by assigning `__signature__` and `__annotations__`. The latter is enough for
    typeguard (at least v2), but does not work with beartype (at least v16).
    """
    pos = []
    pos_or_key = []
    varpos = []
    key = []
    varkey = []
    for p in signature.parameters.values():
        if p.kind == inspect.Parameter.POSITIONAL_ONLY:
            pos.append(p)
        elif p.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD:
            pos_or_key.append(p)
        elif p.kind == inspect.Parameter.VAR_POSITIONAL:
            varpos.append(p)
        elif p.kind == inspect.Parameter.KEYWORD_ONLY:
            key.append(p)
        elif p.kind == inspect.Parameter.VAR_KEYWORD:
            varkey.append(p)
        else:
            assert False

    param_names = frozenset(signature.parameters.keys())
    if output:
        output_name = _gensym(param_names, prefix="ret")
        outstr = "return " + output_name
        param_names = param_names | frozenset({output_name})
        key.append(inspect.Parameter(output_name, kind=inspect.Parameter.KEYWORD_ONLY))
    else:
        outstr = "pass"

    scope = {name: None}
    name_to_annotation = {}
    name_to_default = {}
    param_triples = (
        (p.name, p.annotation, p.default) for p in signature.parameters.values()
    )
    if output:
        triples = it.chain(
            param_triples,
            [
                ("return", signature.return_annotation, inspect.Signature.empty),
                (output_name, Any, inspect.Signature.empty),
            ],
        )
    else:
        triples = it.chain(
            param_triples,
            [("return", signature.return_annotation, inspect.Signature.empty)],
        )
    for p_name, p_annotation, p_default in triples:
        annotation_name = _gensym(frozenset(scope.keys()) | param_names, prefix="T")
        name_to_annotation[p_name] = annotation_name
        if p_annotation is inspect.Signature.empty or isinstance(p_annotation, str):
            # We don't handle partially-stringified local annotations, e.g.
            # `type["Foo"]` for some local type `Foo`. There is no real way to detect
            # those, and they will probably just error out for the user later.
            scope[annotation_name] = Any
        else:
            scope[annotation_name] = p_annotation
        default_name = _gensym(frozenset(scope.keys()) | param_names, prefix="default")
        name_to_default[p_name] = default_name
        scope[default_name] = p_default

    argstr_pieces = []
    if len(pos) > 0:
        for p in pos:
            argstr_pieces.append(_make_argpiece(p, name_to_annotation, name_to_default))
        argstr_pieces.append("/")
    if len(pos_or_key) > 0:
        for p in pos_or_key:
            argstr_pieces.append(_make_argpiece(p, name_to_annotation, name_to_default))
    if len(varpos) == 1:
        [p] = varpos
        argstr_pieces.append(
            "*" + _make_argpiece(p, name_to_annotation, name_to_default)
        )
    else:
        assert len(varpos) == 0
        if len(key) > 0:
            argstr_pieces.append("*")
    if len(key) > 0:
        for p in key:
            argstr_pieces.append(_make_argpiece(p, name_to_annotation, name_to_default))
    if len(varkey) == 1:
        [p] = varkey
        argstr_pieces.append(
            "**" + _make_argpiece(p, name_to_annotation, name_to_default)
        )
    else:
        assert len(varkey) == 0
    argstr = ", ".join(argstr_pieces)

    if signature.return_annotation is inspect.Signature.empty:
        retstr = ""
    else:
        retstr = f"-> {name_to_annotation['return']}"

    fnstr = f"def {name}({argstr}){retstr}:\n    {outstr}"
    exec(fnstr, scope)
    fn = scope[name]
    del scope[name]  # Avoids introducing a reference cycle.
    fn.__module__ = module
    fn.__qualname__ = qualname
    assert fn is not None
    if output:
        return fn, output_name
    else:
        return fn


def _gensym(names: frozenset[str], prefix: str) -> str:
    assert prefix.isidentifier()
    output_index = 0
    output_name = prefix + str(output_index)
    while output_name in names:
        output_index += 1
        output_name = prefix + str(output_index)
    assert output_name.isidentifier()
    return output_name


def _make_argpiece(p, name_to_annotation, name_to_default):
    if p.default is inspect.Signature.empty:
        return f"{p.name}: {name_to_annotation[p.name]}"
    else:
        return f"{p.name}: {name_to_annotation[p.name]} = {name_to_default[p.name]}"


def _get_problem_arg(
    param_signature: inspect.Signature, args, kwargs, arguments, module, typechecker
) -> NoReturn:
    """Works out which parameter most likely caused a type-check failure.

    Each parameter is tried in turn. A throwaway function is generated in which only
    that parameter keeps its annotation, `typechecker` is applied to it, and it is
    called with the original `args` and `kwargs`. The first parameter for which that
    call raises is the one reported.

    This function never returns normally: the answer is delivered by raising
    `TypeCheckError`, and the caller extracts the message from it. Raising rather than
    returning aids debugging, as it is then possible to walk down the stack to the
    underlying issue. The message is empty if no single parameter could be blamed.
    """
    # No performance concerns, as this is only used when we're about to raise an error
    # anyway.
    unannotated = inspect.Parameter.empty
    for keep_name, keep_param in param_signature.parameters.items():
        keep_annotation = _remove_typing(keep_param.annotation)
        # Same parameters as before, but with every annotation other than the one under
        # test removed.
        isolated_params = [
            p if p_name == keep_name else p.replace(annotation=unannotated)
            for p_name, p in param_signature.parameters.items()
        ]
        probe = _make_fn_with_signature(
            "check_single_arg",
            "check_single_arg",
            module,
            inspect.Signature(isolated_params),
            output=False,
        )
        # but no `jaxtyped`; keep the same environment.
        probe = _apply_typechecker(typechecker, probe)
        try:
            probe(*args, **kwargs)
        except Exception as e:
            keep_value = _pformat(arguments[keep_name], short_self=False)
            raise TypeCheckError(
                f"\nThe problem arose whilst typechecking parameter '{keep_name}'.\n"
                f"Actual value: {keep_value}\n"
                f"Expected type: {keep_annotation}."
            ) from e

    # Could not localise the problem to a single argument -- probably due to
    # e.g. a mismatched typevar, which each individual argument is okay with.
    raise TypeCheckError("")


def _remove_typing(x):
    x = str(x)
    x = x.replace(" jaxtyping.", " ")
    x = x.replace("[jaxtyping.", "[")
    x = x.replace("'jaxtyping.", "'")
    x = x.replace(" typing.", " ")
    x = x.replace("[typing.", "[")
    x = x.replace("'typing.", "'")
    return x


def _pformat(x, short_self: bool):
    """Pretty-formats `x` for inclusion in an error message.

    `eqx.tree_pformat` is tried first. If anything in that setup raises, the fallback
    is `wadler_lindig.pformat`.

    If `short_self` is true and `x["self"]` exists, then the chosen formatter is
    configured to treat that object specially: via `truncate_leaf` for equinox, and via
    a `custom` hook rendering it as `ClassName(...)` for wadler_lindig.

    If the formatting call itself raises, a placeholder string naming `type(x)` is
    returned instead.
    """
    # No performance concerns from delayed imports -- this is only used when we're about
    # to raise an error anyway.
    try:
        # If we can, use `eqx.tree_pformat`, which wraps `wadler_lindig.pformat` with
        # understanding of a few other JAX-specific things.
        import equinox as eqx

        formatter = eqx.tree_pformat

        if short_self:
            try:
                self_obj = x["self"]
            except KeyError:
                pass
            else:
                formatter = ft.partial(
                    formatter, truncate_leaf=lambda leaf: leaf is self_obj
                )
    except Exception:
        # Failing that fall back to `wadler_lindig.pformat` directly.
        import wadler_lindig

        formatter = wadler_lindig.pformat

        if short_self:
            try:
                self_obj = x["self"]
            except KeyError:
                pass
            else:

                def _abbreviate_self(obj):
                    if obj is self_obj:
                        return wadler_lindig.TextDoc(f"{type(obj).__name__}(...)")

                formatter = ft.partial(formatter, custom=_abbreviate_self)

    try:
        return formatter(x)
    except Exception:
        return f"<Exception raised when pretty-formatting object of type {type(x)}.>"


class _jaxtyping_note_str(str):
    """Used with `_no_jaxtyping_note` to flag that a note came from jaxtyping."""


def _no_jaxtyping_note(e: Exception) -> bool:
    """Checks if any of the exception's notes are from jaxtyping."""
    try:
        notes = e.__notes__
    except AttributeError:
        return True
    else:
        for note in notes:
            if isinstance(note, _jaxtyping_note_str):
                return False
        return True


# Horizontal rule placed at the start of the note that the old-style
# (`typechecker=None`) `jaxtyped` wrapper attaches to exceptions.
_spacer = "--------------------\n"
